"""
Phase 8: Referee Revision Analyses
Development Threshold Paper

1. Distance-to-exit control (mechanical crossing concern)
2. Savings → crossing bridging regression
3. Schoenfeld residual test for Cox PH assumption
4. Competing risk justification (fell-back analysis)
5. Calibration statistics (AUC, Brier, decile calibration)
"""

import pandas as pd
import numpy as np
from scipy import stats
import statsmodels.api as sm
from statsmodels.duration.hazard_regression import PHReg
from sklearn.metrics import roc_auc_score, brier_score_loss
import sys
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
# Country-year panel, restricted to observations up to and including 2024.
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
panel = full[full['year'] <= 2024].copy()

# Resource rents are an optional add-on. If the file is missing or unreadable,
# lacks the expected columns, or cannot be merged, fall back to an all-NaN
# column so later code can still reference `resource_rents_gdp`.
# Only data/IO errors are handled here (pandas' EmptyDataError, ParserError and
# MergeError are ValueError subclasses); KeyboardInterrupt and programming
# errors are deliberately not swallowed, and the fallback is reported.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except (OSError, KeyError, ValueError) as exc:
    print(f"Resource rents not merged ({type(exc).__name__}: {exc}); using NaN")
    panel['resource_rents_gdp'] = np.nan

# One row per country with zone-crossing information.
cross_df = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv')
# GDP-per-capita bounds of the zone: entry at LOWER, exit at UPPER.
LOWER = 9000
UPPER = 25000

# Zone entrants: rows with a recorded `cross_9k` whose status is not
# 'Always above'.
zone = cross_df[cross_df['cross_9k'].notna() & (cross_df['status'] != 'Always above')].copy()

print(f"Zone entrants: {len(zone)}")
print(f"  Crossed: {zone['exited_above'].sum()}")
print(f"  Did not cross: {(zone['exited_above'] == 0).sum()}")

# Collection for output table: each analysis appends one dict per reported row.
revision_results = []

# ═══════════════════════════════════════════════════════════════════════
# 1. DISTANCE-TO-EXIT CONTROL
# ═══════════════════════════════════════════════════════════════════════
print(
    "\n" + "=" * 70,
    "1. DISTANCE-TO-EXIT CONTROL",
    "=" * 70,
    "Referee concern: countries entering closer to $25k mechanically cross.",
    sep="\n",
)

# Two measures of how close a country was to the exit threshold at entry:
# the log gap between UPPER and entry GDP per capita, and the within-sample
# percentile rank of entry GDP per capita.
# NOTE(review): an entry value at or above UPPER would make the gap
# non-positive and np.log would return -inf/NaN — presumably never the case
# for zone entrants; confirm against the data.
entry_gdp = zone['gdp_pc_ppp_entry']
gap_to_exit = UPPER - entry_gdp
zone['log_dist_to_exit'] = np.log(gap_to_exit)
zone['entry_pctile'] = entry_gdp.rank(pct=True)

# Shared helper: fit one logit specification, print it, and return the fit plus key stats
def run_logit_report(df, y_var, x_vars, label):
    """Fit a logit of ``y_var`` on ``x_vars`` plus a constant and print the result.

    Rows with a missing value in any of the used columns are dropped first.
    With fewer than 15 complete rows nothing is fitted.

    Args:
        df: DataFrame holding the outcome and regressor columns.
        y_var: Name of the outcome column.
        x_vars: List of regressor column names (a constant is added).
        label: Human-readable specification name used in the printed output
            and stored in the returned stats.

    Returns:
        A ``(model, stats)`` tuple. ``model`` is the fitted statsmodels
        results object and ``stats`` a dict with keys ``label``, ``n``,
        ``pseudo_r2``, ``aic``, ``Z1_coef`` and ``Z1_pval`` (the last two are
        NaN when ``Z_1_entry`` is not a regressor). When the sample is too
        small or fitting/reporting raises, ``(None, {})`` is returned and a
        message is printed instead.
    """
    columns = [y_var] + x_vars
    complete = df[columns].dropna()
    n_obs = len(complete)
    if n_obs < 15:
        print(f"  {label}: insufficient obs ({n_obs})")
        return None, {}

    outcome = complete[y_var]
    design = sm.add_constant(complete[x_vars])
    try:
        fit = sm.Logit(outcome, design).fit(disp=0)
        print(f"\n  {label} (n={n_obs}, pseudo-R²={fit.prsquared:.3f}, AIC={fit.aic:.1f}):")
        for name in x_vars:
            estimate = fit.params[name]
            std_err = fit.bse[name]
            p_value = fit.pvalues[name]
            # Conventional significance stars at the 1% / 5% / 10% levels.
            if p_value < 0.01:
                stars = '***'
            elif p_value < 0.05:
                stars = '**'
            elif p_value < 0.1:
                stars = '*'
            else:
                stars = ''
            print(f"    {name:30s}  coef={estimate:8.4f}  se={std_err:.4f}  p={p_value:.4f}{stars}")
        key_stats = {
            'label': label,
            'n': n_obs,
            'pseudo_r2': fit.prsquared,
            'aic': fit.aic,
            'Z1_coef': fit.params.get('Z_1_entry', np.nan),
            'Z1_pval': fit.pvalues.get('Z_1_entry', np.nan),
        }
    except Exception as e:
        # Any fitting or reporting failure is reported and treated as "no result".
        print(f"  {label}: failed — {e}")
        return None, {}
    return fit, key_stats

# 1a-1d: the baseline Z₁-only logit, then the baseline plus one
# entry-position control at a time (log distance, percentile, raw GDP).
m_base, r_base = run_logit_report(
    zone, 'exited_above', ['Z_1_entry'], 'Baseline: Z₁ only')
m_dist, r_dist = run_logit_report(
    zone, 'exited_above', ['Z_1_entry', 'log_dist_to_exit'],
    'Z₁ + log(distance to exit)')
m_pctile, r_pctile = run_logit_report(
    zone, 'exited_above', ['Z_1_entry', 'entry_pctile'],
    'Z₁ + entry percentile')
m_gdp, r_gdp = run_logit_report(
    zone, 'exited_above', ['Z_1_entry', 'gdp_pc_ppp_entry'],
    'Z₁ + GDP/cap at entry')

# Side-by-side table of the Z₁ coefficient across the four specifications;
# every specification that produced a result is also added to the output table.
print("\n  SUMMARY — Distance-to-exit controls:")
print(f"  {'Specification':35s} {'Z₁ coef':>10s} {'Z₁ p-val':>10s} {'pseudo-R²':>10s}")
print("  " + "-" * 68)
summary_specs = [
    ('Baseline: Z₁ only', r_base),
    ('+ log(dist to exit)', r_dist),
    ('+ entry percentile', r_pctile),
    ('+ GDP/cap at entry', r_gdp),
]
for spec_name, spec_stats in summary_specs:
    if not spec_stats:
        # Empty dict: the specification was skipped or failed.
        continue
    z1c = spec_stats.get('Z1_coef', np.nan)
    z1p = spec_stats.get('Z1_pval', np.nan)
    pr2 = spec_stats.get('pseudo_r2', np.nan)
    if z1p < 0.01:
        sig = '***'
    elif z1p < 0.05:
        sig = '**'
    elif z1p < 0.1:
        sig = '*'
    else:
        sig = ''
    print(f"  {spec_name:35s} {z1c:10.4f}{sig:3s} {z1p:10.4f} {pr2:10.3f}")
    revision_results.append({
        'analysis': 'Distance-to-exit',
        'specification': spec_name,
        'Z1_coef': z1c,
        'Z1_pval': z1p,
        'pseudo_r2': pr2,
        'n': spec_stats['n'],
    })

# ═══════════════════════════════════════════════════════════════════════
# 2. SAVINGS → CROSSING BRIDGING REGRESSION
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2. SAVINGS → CROSSING BRIDGING REGRESSION")
print("=" * 70)
print("Does savings rate at/near entry predict crossing directly?")

# Create both savings columns up front. Previously they only came into
# existence inside the loop, so with no zone rows the regressions below
# failed with a KeyError on the column selection.
zone['savings_first5'] = np.nan
zone['savings_at_entry'] = np.nan

# Average gross savings over the window starting at zone entry.
# NOTE(review): the window is inclusive at both ends (entry .. entry + 5),
# i.e. up to six panel years, while the labels say "first 5 years".
# Left unchanged so reported numbers stay reproducible — confirm intent.
savings_panel = panel[['iso3', 'year', 'gross_savings_gdp']]
for idx, row in zone.iterrows():
    entry = row['cross_9k']
    if pd.isna(entry):
        continue
    in_window = (
        (savings_panel['iso3'] == row['iso3'])
        & (savings_panel['year'] >= entry)
        & (savings_panel['year'] <= entry + 5)
    )
    # Series.mean() skips NaN and returns NaN for an empty or all-missing
    # window, so no separate "any non-missing value" check is needed.
    zone.loc[idx, 'savings_first5'] = savings_panel.loc[in_window, 'gross_savings_gdp'].mean()
    zone.loc[idx, 'savings_at_entry'] = row.get('gross_savings_gdp_entry', np.nan)

# 2a. Savings only → crossing
print("\n  2a. Savings (first 5 years avg) → P(crossing)")
m_sav_only, r_sav_only = run_logit_report(zone, 'exited_above',
    ['savings_first5'], 'Savings (5yr avg) only')

# 2b. Savings at entry only
print("\n  2b. Savings at entry → P(crossing)")
m_sav_entry, r_sav_entry = run_logit_report(zone, 'exited_above',
    ['savings_at_entry'], 'Savings at entry only')

# 2c. Z₁ + savings → does Z₁ attenuate?
print("\n  2c. Z₁ + savings (first 5yr avg) — attenuation test")
m_z1_sav, r_z1_sav = run_logit_report(zone, 'exited_above',
    ['Z_1_entry', 'savings_first5'], 'Z₁ + savings (5yr avg)')

# 2d. Z₁ + savings at entry
print("\n  2d. Z₁ + savings at entry — attenuation test")
m_z1_sav_e, r_z1_sav_e = run_logit_report(zone, 'exited_above',
    ['Z_1_entry', 'savings_at_entry'], 'Z₁ + savings at entry')

# Attenuation = percentage fall in the Z₁ coefficient relative to the
# baseline (Z₁-only) coefficient once savings is controlled for.
# NOTE(review): the baseline is fitted on all rows with Z₁, the controlled
# models only on rows that also have savings, so part of the measured
# attenuation may reflect the sample change — consider refitting the
# baseline on the matching sample.
print("\n  ATTENUATION SUMMARY:")
if r_base and r_z1_sav:
    atten_5yr = (1 - r_z1_sav.get('Z1_coef', 0) / r_base.get('Z1_coef', 1)) * 100
    print(f"  Z₁ attenuation with 5yr savings control: {atten_5yr:.1f}%")
    revision_results.append({
        'analysis': 'Savings bridging',
        'specification': 'Z₁ attenuation with savings (5yr)',
        'Z1_coef': r_z1_sav.get('Z1_coef'),
        'Z1_pval': r_z1_sav.get('Z1_pval'),
        'pseudo_r2': r_z1_sav.get('pseudo_r2'),
        'n': r_z1_sav.get('n'),
        'note': f'Attenuation: {atten_5yr:.1f}%',
    })
if r_base and r_z1_sav_e:
    atten_entry = (1 - r_z1_sav_e.get('Z1_coef', 0) / r_base.get('Z1_coef', 1)) * 100
    print(f"  Z₁ attenuation with entry savings control: {atten_entry:.1f}%")
    revision_results.append({
        'analysis': 'Savings bridging',
        'specification': 'Z₁ attenuation with savings (entry)',
        'Z1_coef': r_z1_sav_e.get('Z1_coef'),
        'Z1_pval': r_z1_sav_e.get('Z1_pval'),
        'pseudo_r2': r_z1_sav_e.get('pseudo_r2'),
        'n': r_z1_sav_e.get('n'),
        'note': f'Attenuation: {atten_entry:.1f}%',
    })

# ═══════════════════════════════════════════════════════════════════════
# 3. SCHOENFELD RESIDUAL TEST FOR COX PH ASSUMPTION
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3. SCHOENFELD RESIDUAL TEST FOR COX PH ASSUMPTION")
print("=" * 70)

# Build survival data: one row per zone entrant with a duration, an event
# indicator (exited_above) and the covariate Z_1_entry.
surv = zone[['iso3', 'exited_above', 'years_in_zone', 'Z_1_entry', 'cross_9k']].copy()
# Duration is the recorded years in zone; where that is missing (presumably
# the still-censored spells) use the time from zone entry to 2024.
surv['duration'] = surv['years_in_zone'].copy()
surv.loc[surv['duration'].isna(), 'duration'] = 2024 - surv.loc[surv['duration'].isna(), 'cross_9k']
# Keep strictly positive durations with a non-missing covariate.
surv = surv[surv['duration'].notna() & (surv['duration'] > 0) & surv['Z_1_entry'].notna()].copy()

print(f"\nSurvival sample: n={len(surv)}, events={surv['exited_above'].sum()}, "
      f"censored={(surv['exited_above']==0).sum()}")

# The full-sample hazard ratio is reported again by the fell-back robustness
# check further down. Bind it before the try so that a failed fit here leaves
# a printable NaN rather than an undefined name that would make that later
# check discard its own (valid) estimate.
hr = np.nan

# Fit Cox PH using statsmodels PHReg
try:
    exog = surv[['Z_1_entry']].values
    time = surv['duration'].values
    event = surv['exited_above'].values.astype(int)

    cox = PHReg(time, exog, status=event)
    cox_result = cox.fit()

    print(f"\n  Cox PH model:")
    print(f"  Z₁ coef (log-HR): {cox_result.params[0]:.4f}")
    print(f"  Z₁ SE: {cox_result.bse[0]:.4f}")
    print(f"  Z₁ p-value: {cox_result.pvalues[0]:.4f}")
    hr = np.exp(cox_result.params[0])
    print(f"  Hazard ratio: {hr:.3f}")

    # Manual Schoenfeld residuals (more reliable than statsmodels attribute):
    # for each event, covariate of the subject that had the event minus the
    # exp(beta * z)-weighted mean covariate over everyone still at risk.
    beta = cox_result.params[0]
    event_mask = event == 1
    event_times_all = time[event_mask]
    z1_vals = exog[:, 0]
    z1_events = z1_vals[event_mask]

    # Sort events by time
    sort_idx = np.argsort(event_times_all)
    event_times_sorted = event_times_all[sort_idx]
    z1_events_sorted = z1_events[sort_idx]

    schoen_resids = []
    for i, t_i in enumerate(event_times_sorted):
        # Risk set: all subjects still at risk at time t_i
        at_risk = (time >= t_i)
        z_risk = z1_vals[at_risk]
        # Weighted mean of covariate in risk set
        exp_bz = np.exp(beta * z_risk)
        weighted_mean = np.sum(z_risk * exp_bz) / np.sum(exp_bz)
        schoen_resids.append(z1_events_sorted[i] - weighted_mean)

    schoen_resids = np.array(schoen_resids)

    # PH check: residuals should show no trend in event time.
    rho, p_schoenfeld = stats.spearmanr(event_times_sorted, schoen_resids)
    print(f"\n  Schoenfeld residual test (PH assumption):")
    print(f"  n events: {len(schoen_resids)}")
    print(f"  Spearman rho(time, residual): {rho:.4f}")
    print(f"  p-value: {p_schoenfeld:.4f}")

    rho_p, p_pearson = stats.pearsonr(event_times_sorted, schoen_resids)
    print(f"  Pearson r(time, residual): {rho_p:.4f}")
    print(f"  p-value (Pearson): {p_pearson:.4f}")

    # Also test with log(time)
    rho_log, p_log = stats.spearmanr(np.log(event_times_sorted), schoen_resids)
    print(f"  Spearman rho(log(time), residual): {rho_log:.4f}, p={p_log:.4f}")

    # The verdict is based on the Spearman test against untransformed time.
    # A NaN p-value means the test could not be computed; it must not be
    # reported as "PH holds" just because `nan < 0.05` is False.
    if pd.isna(p_schoenfeld):
        print("  Schoenfeld test not computable — PH assumption could not be assessed")
        ph_verdict = 'Not testable (p undefined)'
    elif p_schoenfeld < 0.05:
        print("  *** PH ASSUMPTION VIOLATED — Z₁ effect varies over time")
        ph_verdict = 'VIOLATED (p<0.05)'
    else:
        print("  PH assumption NOT rejected — proportional hazards holds")
        ph_verdict = f'Not rejected (p={p_schoenfeld:.3f})'

    revision_results.append({
        'analysis': 'Schoenfeld PH test',
        'specification': 'Cox PH: Z₁ → crossing',
        'Z1_coef': cox_result.params[0],
        'Z1_pval': cox_result.pvalues[0],
        'n': len(surv),
        'note': f'HR={hr:.3f}, Schoenfeld rho={rho:.4f}, p={p_schoenfeld:.4f} — {ph_verdict}',
    })

except Exception as e:
    # Best-effort section: report the failure with a traceback and carry on.
    print(f"  Cox model failed: {e}")
    import traceback
    traceback.print_exc()


# ═══════════════════════════════════════════════════════════════════════
# 4. COMPETING RISK: FELL-BACK JUSTIFICATION
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("4. COMPETING RISK / FELL-BACK JUSTIFICATION")
print("=" * 70)

# Identify fell-back countries from the status recorded in the crossing data.
fell_back_statuses = ['Entered zone then fell back', 'Fell back below']
fell_back_countries = zone[zone['status'].isin(fell_back_statuses)].copy()
print(f"\nFell-back countries: {len(fell_back_countries)}")
for _, row in fell_back_countries.iterrows():
    print(f"  {row['iso3']}: {row['status']}, Z₁ entry={row.get('Z_1_entry', np.nan):.3f}, "
          f"GDP entry=${row.get('gdp_pc_ppp_entry', np.nan):,.0f}")

# Also check phase7 table17 for broader fell-back list.
# `fb_isos` defines the exclusion sample for the robustness checks that
# follow, so loading it is kept separate from the descriptive printing:
# only a failure to read the table (or a missing `iso3` column) triggers the
# fallback to the narrower status-based list, and the fallback is announced.
# Previously a bare `except:` wrapped everything, so an error while printing
# silently swapped the exclusion list.
try:
    fb_table = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/output/tables/table17_fell_back.csv')
    fb_isos = fb_table['iso3'].tolist()
except (OSError, KeyError, ValueError) as exc:
    fb_isos = fell_back_countries['iso3'].tolist()
    print(f"\n  Phase7 fell-back table unavailable ({type(exc).__name__}: {exc}); "
          f"using the {len(fb_isos)} fell-back countries from the crossing data")
else:
    print(f"\nBroader fell-back list (from phase7): {len(fb_table)} countries")
    # Hand-coded classification of why each country fell back.
    conflict_states = ['SYR', 'IRQ', 'LBY', 'YEM', 'UKR', 'AFG']
    resource_collapse = ['VEN', 'GAB', 'COG', 'AGO', 'GNQ', 'TKM']
    for _, row in fb_table.iterrows():
        iso = row['iso3']
        rents_avg = row.get('resource_rents_avg')
        if iso in conflict_states:
            reason = "CONFLICT"
        elif iso in resource_collapse or (pd.notna(rents_avg) and rents_avg > 10):
            reason = "RESOURCE COLLAPSE"
        else:
            # Residual category: everything not matched above.
            reason = "POST-SOVIET TRANSITION"
        # A missing decline column prints as "nan" instead of aborting the list.
        decline = row.get('decline_pct', np.nan)
        print(f"  {iso}: {reason} (decline {decline:.0f}%)")

# 4a. Re-fit the Cox model on the survival sample without the fell-back
# countries and compare its hazard ratio with the full-sample one (`hr`).
print("\n  4a. Cox model EXCLUDING fell-back countries")
not_fell_back = ~surv['iso3'].isin(fb_isos)
surv_no_fb = surv[not_fell_back].copy()
n_events_no_fb = surv_no_fb['exited_above'].sum()
print(f"  Sample: n={len(surv_no_fb)}, events={n_events_no_fb}")

try:
    cox_nfb = PHReg(
        surv_no_fb['duration'].values,
        surv_no_fb[['Z_1_entry']].values,
        status=surv_no_fb['exited_above'].values.astype(int),
    )
    cox_nfb_result = cox_nfb.fit()
    coef_nfb = cox_nfb_result.params[0]
    pval_nfb = cox_nfb_result.pvalues[0]
    hr_nfb = np.exp(coef_nfb)
    # Conventional significance stars at the 1% / 5% / 10% levels.
    if pval_nfb < 0.01:
        sig_nfb = '***'
    elif pval_nfb < 0.05:
        sig_nfb = '**'
    elif pval_nfb < 0.1:
        sig_nfb = '*'
    else:
        sig_nfb = ''
    print(f"  Z₁ coef: {coef_nfb:.4f}, HR={hr_nfb:.3f}, p={pval_nfb:.4f}{sig_nfb}")
    print(f"  Full sample HR: {hr:.3f} → Excluding fell-back HR: {hr_nfb:.3f}")

    revision_results.append({
        'analysis': 'Competing risk',
        'specification': 'Cox excl. fell-back countries',
        'Z1_coef': coef_nfb,
        'Z1_pval': pval_nfb,
        'n': len(surv_no_fb),
        'note': f'HR={hr_nfb:.3f} (full sample HR={hr:.3f})',
    })
except Exception as e:
    # Best-effort robustness check: report and move on.
    print(f"  Failed: {e}")

# 4b. Same exclusion applied to the baseline logit; the Z₁ coefficient is
# reported next to the full-sample one when both fits are available.
print("\n  4b. Logit EXCLUDING fell-back countries")
zone_no_fb = zone.loc[~zone['iso3'].isin(fb_isos)].copy()
m_nfb, r_nfb = run_logit_report(
    zone_no_fb, 'exited_above', ['Z_1_entry'], 'Z₁ logit excl. fell-back')
if r_nfb and r_base:
    full_coef = r_base['Z1_coef']
    nfb_coef = r_nfb['Z1_coef']
    print(f"\n  Full sample Z₁ coef: {full_coef:.4f}")
    print(f"  Excl. fell-back Z₁ coef: {nfb_coef:.4f}")
    nfb_record = {
        'analysis': 'Competing risk',
        'specification': 'Logit excl. fell-back',
    }
    # Copy the headline statistics of the restricted fit, in table order.
    for key in ('Z1_coef', 'Z1_pval', 'pseudo_r2', 'n'):
        nfb_record[key] = r_nfb[key]
    nfb_record['note'] = f'Full sample coef={full_coef:.4f}'
    revision_results.append(nfb_record)

# 4c. Compare Z₁ at entry across fell-back, stuck-in-zone and crossing
# countries, with a Welch t-test (unequal variances) for fell-back vs stuck.
# NOTE(review): the "Conclusion" lines are printed whenever both groups have
# at least two rows, regardless of the t-test outcome — confirm this is
# intended before quoting the output.
print("\n  4c. Are fell-back countries demographically distinct from stuck-in-zone?")
stuck = zone[zone['status'] == 'Stuck in zone']
if min(len(fell_back_countries), len(stuck)) >= 2:
    fb_z1 = fell_back_countries['Z_1_entry'].dropna()
    stuck_z1 = stuck['Z_1_entry'].dropna()
    crossed_mask = zone['exited_above'] == 1
    crosser_z1 = zone.loc[crossed_mask, 'Z_1_entry'].dropna()
    print(f"  Fell-back Z₁ at entry: mean={fb_z1.mean():.3f} (n={len(fb_z1)})")
    print(f"  Stuck-in-zone Z₁ at entry: mean={stuck_z1.mean():.3f} (n={len(stuck_z1)})")
    print(f"  Crossers Z₁ at entry: mean={crosser_z1.mean():.3f} (n={len(crosser_z1)})")
    # The test needs at least two non-missing Z₁ values in each group.
    if min(len(fb_z1), len(stuck_z1)) >= 2:
        t_stat, p_value = stats.ttest_ind(fb_z1, stuck_z1, equal_var=False)
        print(f"  Fell-back vs stuck t-test: p={p_value:.4f}")
    print("  Conclusion: Fell-back countries are geopolitical/resource shocks,")
    print("  not demographic failures — justifies treating as censored.")


# ═══════════════════════════════════════════════════════════════════════
# 5. CALIBRATION STATISTICS
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("5. CALIBRATION STATISTICS (AUC, BRIER, DECILE CALIBRATION)")
print("=" * 70)

# Use baseline logit (Z₁ only)
if m_base is not None:
    # Complete cases on outcome and Z₁ — the same row selection the baseline
    # logit was fitted on — scored with the fitted model (in-sample).
    calib_sample = zone[['exited_above', 'Z_1_entry']].dropna()
    y_true = calib_sample['exited_above'].values
    X_calib = sm.add_constant(calib_sample[['Z_1_entry']])
    y_pred = m_base.predict(X_calib)

    # 5a. AUC
    auc = roc_auc_score(y_true, y_pred)
    print(f"\n  5a. AUC (ROC): {auc:.4f}")

    # 5b. Brier score
    brier = brier_score_loss(y_true, y_pred)
    # Baseline Brier (predict mean)
    base_rate = y_true.mean()
    brier_baseline = np.mean((y_true - base_rate) ** 2)
    # Skill score: share of the baseline Brier score removed by the model.
    brier_skill = 1 - brier / brier_baseline
    print(f"  5b. Brier score: {brier:.4f}")
    print(f"      Baseline Brier (predict mean): {brier_baseline:.4f}")
    print(f"      Brier skill score: {brier_skill:.4f}")

    # 5c. Calibration by predicted-probability bin.
    # NOTE(review): the printed labels say "decile", but five quantile bins
    # (quintiles) are used; the strings are left as they were.
    print(f"\n  5c. Calibration by predicted probability decile:")
    calib_df = pd.DataFrame({'y_true': y_true, 'y_pred': y_pred})
    # Use quintiles since n≈87
    n_bins = 5
    # duplicates='drop' may yield fewer than n_bins bins when edges coincide.
    calib_df['bin'] = pd.qcut(calib_df['y_pred'], n_bins, labels=False, duplicates='drop')

    print(f"  {'Bin':>4s} {'Pred range':>20s} {'Avg pred':>10s} {'Obs rate':>10s} {'n':>5s} {'n_cross':>8s}")
    print("  " + "-" * 60)
    calib_rows = []
    for b in sorted(calib_df['bin'].unique()):
        sub = calib_df[calib_df['bin'] == b]
        pred_lo = sub['y_pred'].min()
        pred_hi = sub['y_pred'].max()
        avg_pred = sub['y_pred'].mean()
        obs_rate = sub['y_true'].mean()
        n_bin = len(sub)
        n_cross = sub['y_true'].sum()
        print(f"  {b:4d} [{pred_lo:.3f}, {pred_hi:.3f}] {avg_pred:10.3f} {obs_rate:10.3f} {n_bin:5d} {n_cross:8.0f}")
        calib_rows.append({
            'bin': b, 'pred_lo': pred_lo, 'pred_hi': pred_hi,
            'avg_predicted': avg_pred, 'observed_rate': obs_rate,
            'n': n_bin, 'n_crossed': int(n_cross),
        })

    # Hosmer-Lemeshow test: sum over bins of (O - E)^2 / (n * p * (1 - p)),
    # with E = n * p and p the bin's mean predicted probability, compared
    # against chi-square with (bins - 2) degrees of freedom.
    hl_chi2 = 0.0
    for row in calib_rows:
        expected = row['avg_predicted'] * row['n']
        observed = row['n_crossed']
        # Skip bins whose expected count of events or non-events is zero
        # (the term would divide by zero).
        if expected > 0 and (row['n'] - expected) > 0:
            hl_chi2 += (observed - expected) ** 2 / (expected * (1 - row['avg_predicted']))
    hl_df = len(calib_rows) - 2
    # Survival function = 1 - cdf, without the cancellation error for small p.
    hl_p = stats.chi2.sf(hl_chi2, hl_df) if hl_df > 0 else np.nan
    print(f"\n  Hosmer-Lemeshow chi²={hl_chi2:.3f}, df={hl_df}, p={hl_p:.4f}")
    # An undefined p-value (too few bins) must not be read as a rejection:
    # `nan > 0.05` is False, which used to fall through to "reject H₀".
    if pd.isna(hl_p):
        print("  Hosmer-Lemeshow test not computable (fewer than 3 bins)")
    elif hl_p > 0.05:
        print("  Calibration is adequate (fail to reject H₀ of good fit)")
    else:
        print("  Calibration may be poor (reject H₀)")

    revision_results.append({
        'analysis': 'Calibration',
        'specification': 'Z₁-only logit',
        'n': len(calib_sample),
        'note': f'AUC={auc:.4f}, Brier={brier:.4f}, Brier skill={brier_skill:.4f}, HL p={hl_p:.4f}',
    })

    # 5d. Key prediction: China
    print("\n  5d. China prediction check:")
    chn = panel[panel['iso3'] == 'CHN'].sort_values('year')
    # Zone entry = first panel year with GDP per capita at or above LOWER.
    chn_entry = chn[chn['gdp_pc_ppp'] >= LOWER]['year'].min()
    if pd.notna(chn_entry):
        chn_z1 = chn[chn['year'] == chn_entry]['Z_1'].values
        if len(chn_z1) > 0:
            chn_z1_val = chn_z1[0]
            # Row layout matches the fitted design: [constant, Z₁].
            chn_pred = m_base.predict(np.array([[1.0, chn_z1_val]]))[0]
            print(f"  China Z₁ at zone entry ({int(chn_entry)}): {chn_z1_val:.3f}")
            print(f"  Model-predicted P(crossing): {chn_pred:.4f}")
            print(f"  This is based on the Z₁-only logit (the core model).")

            # Context: what percentile of predictions is this?
            pctile = (y_pred < chn_pred).mean() * 100
            print(f"  Percentile among zone entrants: {pctile:.1f}th")

else:
    print("  Baseline model not available — skipping calibration.")


# ═══════════════════════════════════════════════════════════════════════
# SAVE OUTPUT TABLE
# ═══════════════════════════════════════════════════════════════════════
rule = "=" * 70
print("\n" + rule, "OUTPUT TABLE", rule, sep="\n")

# One row per entry collected in `revision_results`; written to CSV and
# echoed to stdout.
outpath = '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table18_revisions.csv'
results_df = pd.DataFrame(revision_results)
results_df.to_csv(outpath, index=False)
print(f"\nSaved to: {outpath}")
print(results_df.to_string(index=False))

print("\n" + rule, "PHASE 8 COMPLETE", rule, sep="\n")
